// +------------------------------------------------
// | csock.c 
// | hatsusakana@gmail.com 
// +------------------------------------------------

#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#ifdef WIN32
#include <winsock2.h>
typedef int socklen_t;
#else
#include <arpa/inet.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/ioctl.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <unistd.h>
#endif

#include "csock.h"

#ifdef USE_OPENSSL
#include <openssl/rsa.h>
#include <openssl/crypto.h>
#include <openssl/x509.h>
#include <openssl/pem.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#endif

#define CSOCK_BUF_SIZE    256         // Length in chars of the line buffer used by cSOCK_DNS

// +---------------------------------------------------------------
// | Internal helpers
// +---------------------------------------------------------------

/**
 * Strip a trailing line ending from s in place.
 *
 * At most one '\n' is removed from the end, and then at most one '\r', so
 * "text\r\n", "text\n" and "text\r" all become "text".
 *
 * NULL and empty strings are left untouched. The length is tracked while
 * trimming so that a string consisting only of "\n" can never index before
 * the start of the buffer.
 */
static void _cut_enter (char *s) {
    size_t n = 0;

    if (s == NULL) return;

    n = strlen(s);
    if (n > 0 && s[n - 1] == '\n') s[--n] = '\0';
    if (n > 0 && s[n - 1] == '\r') s[--n] = '\0';
}

// +---------------------------------------------------------------
// | cSOCK 
// +---------------------------------------------------------------

/**
 * Process-wide start-up for the socket layer.
 *
 * On WIN32 builds this calls WSAStartup requesting Winsock 2.0. When built
 * with USE_OPENSSL it also initialises the OpenSSL library and loads its
 * error strings. On other builds the function body is empty.
 * The results of these calls are not checked.
 */
void cSOCK_Init () {
#if defined(WIN32)
    WSADATA winsockInfo;
    WSAStartup(MAKEWORD(2, 0), &winsockInfo);
#endif

#if defined(USE_OPENSSL)
    SSL_library_init();
    SSL_load_error_strings();
#endif
}

/**
 * Resolve a host name to a dotted-quad IPv4 string.
 *
 * @param doname  Host name to look up; NULL or "" is rejected.
 * @param s       Output buffer that receives the NUL-terminated address text.
 * @param n       Capacity of s in bytes.
 * @return 1 on success; 0 on invalid arguments, a failed lookup, a result
 *         that is not an IPv4 address, or when s is too small to hold the
 *         whole address (s is left untouched in every failure case).
 */
int cSOCK_GetHostByName (const char *doname, char *s, size_t n) {
    struct hostent *hp = NULL;
    struct in_addr addr;
    const char *text = NULL;
    size_t len = 0;

    if (!doname || *doname == '\0' || !s || n == 0)
        return 0;

    if ((hp = gethostbyname(doname)) == NULL)
        return 0;

    /* inet_ntoa can only format IPv4, so refuse anything else. */
    if (hp->h_addrtype != AF_INET || hp->h_length != (int)sizeof(addr) ||
        hp->h_addr_list == NULL || hp->h_addr_list[0] == NULL)
        return 0;

    /* Copy instead of casting: the raw address bytes carry no alignment guarantee. */
    memcpy(&addr, hp->h_addr_list[0], sizeof(addr));

    text = inet_ntoa(addr);
    if (text == NULL)
        return 0;

    /* A truncated address would be silently wrong, so treat it as a failure. */
    len = strlen(text);
    if (len >= n)
        return 0;

    memcpy(s, text, len + 1);
    return 1;
}

/**
 * Look a host name up in a local override file instead of querying DNS.
 *
 * The file holds one "name address" pair per line, for example:
 *     www.baidu.com 61.135.169.121
 * Lines whose first character is '#' are skipped. The name must be followed
 * by at least one space or tab, so "www.baidu.com" does not match a line for
 * "www.baidu.com.cn". The address is the next run of non-blank characters.
 * If the name appears on several lines, the last usable entry wins.
 *
 * @param dnspath  Path of the override file; NULL or "" is rejected.
 * @param doname   Host name to search for; NULL or "" is rejected.
 * @param s        Output buffer that receives the NUL-terminated address.
 * @param n        Capacity of s in bytes.
 * @return 1 if an entry was found and copied into s, otherwise 0 (invalid
 *         arguments, unreadable file, no entry, or entry too long for s).
 */
int cSOCK_DNS (const char *dnspath, const char *doname, char *s, size_t n) {
    FILE *fp = NULL;
    char temp[CSOCK_BUF_SIZE] = {0};
    size_t len = 0;
    int found = 0;

    if (!dnspath || *dnspath == '\0' || !doname || *doname == '\0' || !s || n == 0)
        return 0;

    if ((fp = fopen(dnspath, "r")) == NULL)
        return 0;

    len = strlen(doname);
    while (fgets(temp, sizeof(temp), fp)) {
        const char *addr = NULL;
        size_t addrLen = 0;

        if (*temp == '#') continue;
        if (strncmp(temp, doname, len) != 0) continue;

        /* A successful strncmp guarantees temp[len] is inside the string
           (possibly its terminator); require a separator so that only a
           whole-name match is accepted. */
        if (temp[len] != ' ' && temp[len] != '\t') continue;

        addr = temp + len;
        addr += strspn(addr, " \t");
        addrLen = strcspn(addr, " \t\r\n");

        /* Skip entries with no address or one that would not fit with its terminator. */
        if (addrLen == 0 || addrLen >= n) continue;

        memcpy(s, addr, addrLen);
        s[addrLen] = '\0';
        found = 1;
    }

    fclose(fp);
    return found;
}

/**
 * Open a TCP connection to a numeric IPv4 address with a connect timeout.
 *
 * The socket is put into non-blocking mode for the connect attempt so that
 * select() can enforce the timeout, and is switched back to blocking mode
 * before it is returned. On every failure path the socket is closed.
 *
 * @param ip       Dotted-quad IPv4 address text (parsed with inet_addr).
 * @param port     TCP port, 1..65535.
 * @param timeout  Connect timeout in seconds.
 * @return The connected socket descriptor, or 0 on any failure (invalid
 *         arguments, socket/ioctl error, refused connection, or timeout).
 */
int cSOCK_Connect (const char *ip, int port, int timeout) {
    int s = 0;
    struct sockaddr_in sockAddr = {0};
    struct timeval timeVal = {timeout, 0};
    fd_set timeset;
#if (defined __WIN32__) || (defined WIN32)
    unsigned long nonBlocking = 1, blocking = 0;
#else
    /* FIONBIO reads an int through the pointer on POSIX systems. */
    int nonBlocking = 1, blocking = 0;
    int error = -1;
    socklen_t len = sizeof(error);
#endif

    if (!ip || port <= 0 || port > 65535)
        return 0;

    sockAddr.sin_family = AF_INET;
    sockAddr.sin_addr.s_addr = inet_addr(ip);
    sockAddr.sin_port = htons((unsigned short)port);

    /* inet_addr reports malformed input as INADDR_NONE; fail before creating a socket. */
    if (sockAddr.sin_addr.s_addr == INADDR_NONE)
        return 0;

#if (defined __WIN32__) || (defined WIN32)
    if ((s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)) == INVALID_SOCKET)
        return 0;

    if (ioctlsocket(s, FIONBIO, &nonBlocking) == SOCKET_ERROR)
        goto fail;

    if (connect(s, (const struct sockaddr *)&sockAddr, sizeof(sockAddr)) == SOCKET_ERROR) {
        /* Anything other than "in progress" is an immediate failure; do not wait for the timeout. */
        if (WSAGetLastError() != WSAEWOULDBLOCK)
            goto fail;

        FD_ZERO(&timeset);
        FD_SET((unsigned int)s, &timeset);
        if (select(0, 0, &timeset, 0, &timeVal) <= 0)
            goto fail;
    }

    if (ioctlsocket(s, FIONBIO, &blocking) == SOCKET_ERROR)
        goto fail;

    return s;

fail:
    closesocket(s);
    return 0;
#else
    if ((s = socket(AF_INET, SOCK_STREAM, 0)) < 0)
        return 0;

    /* FD_SET on a descriptor outside the fd_set range is undefined. */
    if (s >= FD_SETSIZE)
        goto fail;

    if (ioctl(s, FIONBIO, &nonBlocking) == -1)
        goto fail;

    if (connect(s, (struct sockaddr *)&sockAddr, sizeof(sockAddr)) == -1) {
        /* Only an in-progress (or interrupted) connect can still complete.
           Any other error has already been reported, so SO_ERROR would read
           back as 0 and must not be mistaken for success. */
        if (errno != EINPROGRESS && errno != EINTR)
            goto fail;

        FD_ZERO(&timeset);
        FD_SET(s, &timeset);
        if (select(s + 1, NULL, &timeset, NULL, &timeVal) <= 0)
            goto fail;

        /* Writable only means the attempt finished; SO_ERROR tells whether it succeeded. */
        if (getsockopt(s, SOL_SOCKET, SO_ERROR, &error, &len) == -1 || error != 0)
            goto fail;
    }

    if (ioctl(s, FIONBIO, &blocking) == -1)
        goto fail;

    return s;

fail:
    close(s);
    return 0;
#endif
}

/**
 * Apply the same timeout, given in seconds, to both sending and receiving
 * on socket s.
 *
 * The Windows branch passes the value as an int holding sec * 1000; the
 * other branch passes a struct timeval of {sec, 0}. The setsockopt results
 * are not checked.
 */
void cSOCK_SetTimeout (int s, int sec) {
    static const int directions[2] = { SO_SNDTIMEO, SO_RCVTIMEO };
    int k = 0;
#if (defined __WIN32__) || (defined WIN32)
    int span = sec * 1000;
#else
    struct timeval span = {sec, 0};
#endif

    /* Send timeout first, then receive timeout. */
    for (k = 0; k < 2; k++)
        setsockopt(s, SOL_SOCKET, directions[k], (const char*)&span, sizeof(span));
}

/**
 * Send size bytes from data on socket s, calling send() again after a
 * short write until everything has been handed over.
 *
 * On Linux builds MSG_NOSIGNAL is passed to send().
 *
 * @return Number of bytes sent. This is less than size when send() returned
 *         0 or a negative value before all data was written.
 */
size_t cSOCK_Send (int s, const char *data, size_t size) {
    const char *cursor = data;
    size_t sent = 0;
    int flags = 0;

#ifdef __linux__
    flags = MSG_NOSIGNAL;
#endif

    for (;;) {
        int n = send(s, cursor, size - sent, flags);
        if (n <= 0) break;

        sent += (size_t)n;
        if (sent >= size) break;

        cursor += n;
    }

    return sent;
}

/**
 * Send a NUL-terminated string on socket s (the terminator is not sent).
 *
 * @return Number of bytes sent as reported by cSOCK_Send, or 0 when data is
 *         NULL or empty.
 */
size_t cSOCK_SendTXT (int s, const char *data) {
    size_t len = 0;

    if (data) len = strlen(data);
    if (len == 0) return 0;

    return cSOCK_Send(s, data, len);
}

/**
 * Receive up to size bytes into data, calling recv() repeatedly until the
 * buffer is full or recv() returns 0 or a negative value.
 *
 * @return Number of bytes stored in data; less than size when recv()
 *         stopped delivering data first.
 */
size_t cSOCK_Recv (int s, char *data, size_t size) {
    char *cursor = data;
    size_t got = 0;

    for (;;) {
        int n = recv(s, cursor, size - got, 0);
        if (n <= 0) break;

        got += (size_t)n;
        if (got >= size) break;

        cursor += n;
    }

    return got;
}

/**
 * Read one line from socket s, one byte at a time, into line.
 *
 * Reading stops after a '\n' has been stored, after size bytes have been
 * stored, or when cSOCK_Recv fails to deliver a byte. The '\n' is kept in
 * the buffer.
 *
 * The result is NUL-terminated whenever fewer than size bytes were stored.
 * When exactly size bytes were stored there is no room left for a
 * terminator, so callers that treat line as a C string should pass a size
 * one smaller than the buffer's real capacity.
 *
 * @return 1 if the last byte read was '\n' (a complete line), otherwise 0.
 */
int cSOCK_ReadLine (int s, char *line, size_t size) {
    char c = '\0';
    size_t i = 0;

    if (!line || size == 0) return 0;

    while (i < size && cSOCK_Recv(s, &c, 1) == 1) {
        line[i++] = c;
        if (c == '\n') break;
    }

    /* Terminate when there is room; never write past the size the caller gave. */
    if (i < size) line[i] = '\0';

    return c == '\n';
}

/**
 * Close socket s using the platform's call: close() by default,
 * closesocket() on Windows builds. The result is not checked.
 */
void cSOCK_Close (int s) {
#if !defined(__WIN32__) && !defined(WIN32)
    close(s);
#else
    closesocket(s);
#endif
}

/**
 * Close socket s immediately: SO_LINGER is enabled with a linger time of
 * zero before the socket is closed with the platform's close call.
 * Neither result is checked.
 */
void cSOCK_CloseNow (int s) {
    struct linger hard;

    hard.l_onoff = 1;
    hard.l_linger = 0;
    setsockopt(s, SOL_SOCKET, SO_LINGER, (const char*)&hard, sizeof(hard));

#if !defined(__WIN32__) && !defined(WIN32)
    close(s);
#else
    closesocket(s);
#endif
}

// +-------------------------------------------------------
// | USE OpenSSL 
// +-------------------------------------------------------

#ifdef USE_OPENSSL

/** State of one TLS client connection as created by cSOCK_SSL_Connect. */
struct SSLSocket {
    SSL *ssl;       // Session object from SSL_new(); bound to the socket with SSL_set_fd()
    SSL_CTX *ctx;   // Context from SSL_CTX_new(); released together with ssl in cSOCK_SSL_Close
};

/**
 * Perform a TLS client handshake on an already connected socket.
 *
 * A fresh SSL_CTX and SSL object are created for the connection and owned
 * by the returned SSLSocket; release them with cSOCK_SSL_Close. The socket
 * s itself is not closed by this function, on success or on failure.
 *
 * NOTE(review): nothing here configures certificate or host-name
 * verification - confirm whether callers rely on the peer being verified.
 *
 * @param s  Connected socket descriptor.
 * @return A new SSLSocket on success, or NULL if allocation, object
 *         creation, SSL_set_fd or the handshake fails.
 */
SSLSocket *cSOCK_SSL_Connect (int s) {
    SSLSocket *ssls = NULL;

    ssls = (SSLSocket *)malloc(sizeof(*ssls));
    if (!ssls) return NULL;

    /* Start from a known state so the shared cleanup path is always valid. */
    ssls->ssl = NULL;
    ssls->ctx = NULL;

    ssls->ctx = SSL_CTX_new(SSLv23_client_method());
    if (ssls->ctx == NULL)
        goto fail;

    ssls->ssl = SSL_new(ssls->ctx);
    if (ssls->ssl == NULL)
        goto fail;

    if (SSL_set_fd(ssls->ssl, s) != 1)
        goto fail;

    /* SSL_connect returns 1 only for a completed handshake; both 0 and
       negative values are failures. */
    if (SSL_connect(ssls->ssl) != 1)
        goto fail;

    return ssls;

fail:
    /* SSL_free and SSL_CTX_free both accept NULL. */
    SSL_free(ssls->ssl);
    SSL_CTX_free(ssls->ctx);
    free(ssls);
    return NULL;
}

/**
 * Send size bytes from data over a TLS connection, calling SSL_write again
 * until everything has been written.
 *
 * SSL_write takes its length as an int, so large buffers are written in
 * chunks of at most INT_MAX bytes. SSL_write is never called with a length
 * of 0.
 *
 * @return Number of bytes written; less than size when SSL_write returned
 *         0 or a negative value first. Returns 0 for a NULL connection,
 *         NULL data or size 0.
 */
size_t cSOCK_SSL_Send (SSLSocket *ssls, const char *data, size_t size) {
    size_t result = 0;

    if (!ssls || !ssls->ssl || !data || size == 0)
        return 0;

    while (result < size) {
        size_t remaining = size - result;
        int chunk = remaining > (size_t)INT_MAX ? INT_MAX : (int)remaining;
        int re = SSL_write(ssls->ssl, data + result, chunk);

        if (re <= 0) break;
        result += (size_t)re;
    }

    return result;
}

/**
 * Receive up to size bytes from a TLS connection into data, calling
 * SSL_read repeatedly until the buffer is full or SSL_read returns 0 or a
 * negative value.
 *
 * SSL_read takes its length as an int, so large buffers are filled in
 * chunks of at most INT_MAX bytes.
 *
 * @return Number of bytes stored in data; less than size when SSL_read
 *         stopped delivering data first. Returns 0 for a NULL connection,
 *         NULL data or size 0.
 */
size_t cSOCK_SSL_Recv (SSLSocket *ssls, char *data, size_t size) {
    size_t result = 0;

    if (!ssls || !ssls->ssl || !data || size == 0)
        return 0;

    while (result < size) {
        size_t remaining = size - result;
        int chunk = remaining > (size_t)INT_MAX ? INT_MAX : (int)remaining;
        int re = SSL_read(ssls->ssl, data + result, chunk);

        if (re <= 0) break;
        result += (size_t)re;
    }

    return result;
}

/**
 * Read one line from a TLS connection, one byte at a time, into line.
 *
 * Reading stops after a '\n' has been stored, after size bytes have been
 * stored, or when cSOCK_SSL_Recv fails to deliver a byte. The '\n' is kept
 * in the buffer.
 *
 * The result is NUL-terminated whenever fewer than size bytes were stored.
 * When exactly size bytes were stored there is no room left for a
 * terminator, so callers that treat line as a C string should pass a size
 * one smaller than the buffer's real capacity.
 *
 * @return 1 if the last byte read was '\n' (a complete line), otherwise 0.
 */
int cSOCK_SSL_ReadLine (SSLSocket *ssls, char *line, size_t size) {
    char c = '\0';
    size_t i = 0;

    if (!line || size == 0) return 0;

    while (i < size && cSOCK_SSL_Recv(ssls, &c, 1) == 1) {
        line[i++] = c;
        if (c == '\n') break;
    }

    /* Terminate when there is room; never write past the size the caller gave. */
    if (i < size) line[i] = '\0';

    return c == '\n';
}

/**
 * Shut down and release a TLS connection created by cSOCK_SSL_Connect.
 *
 * Calls SSL_shutdown, frees the SSL object, its context and the SSLSocket
 * itself, then sets *ssls to NULL. Does nothing when ssls or *ssls is NULL.
 * The underlying socket descriptor is not closed here.
 */
void cSOCK_SSL_Close(SSLSocket **ssls) {
    SSLSocket *conn = ssls ? *ssls : NULL;

    if (!conn) return;

    SSL_shutdown(conn->ssl);
    SSL_free(conn->ssl);
    SSL_CTX_free(conn->ctx);
    free(conn);

    *ssls = NULL;
}

#endif

/**
 * Run a blocking TCP server on all local IPv4 interfaces.
 *
 * Binds and listens on port, then accepts connections in a loop and calls
 * cb(p, clientSocket) for each one from the calling thread. The accepted
 * socket is never closed here, so the callback is responsible for it.
 * The loop ends when accept() fails; the listening socket is closed on
 * every exit path.
 *
 * @param port  TCP port to listen on, 1..65535.
 * @param p     Opaque pointer passed through to cb unchanged.
 * @param cb    Callback invoked with p and each accepted socket.
 * @return -1 for invalid arguments, -2 if the socket cannot be created,
 *         -3 if bind fails, -4 if listen fails, 0 once the accept loop ends.
 */
int csockserver (int port, void *p, void (*cb)(void *, int)) {
    int serverSock = 0;
    struct sockaddr_in serverSockAddr = {0};
    int clientSock = 0;
    socklen_t clientSockAddrSize = 0;
    struct sockaddr_in clientSockAddr = {0};
    int result = 0;
#ifdef WIN32
    WSADATA wsaData;
#endif

    /* Reject ports that would be silently truncated by the cast to unsigned short. */
    if (port <= 0 || port > 65535 || !cb) return -1;

#ifdef WIN32
    WSAStartup(MAKEWORD(2, 0), &wsaData);
#endif

    /* socket() reports failure with a negative value / INVALID_SOCKET; 0 is a valid descriptor. */
#if (defined __WIN32__) || (defined WIN32)
    if ((serverSock = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP)) == INVALID_SOCKET)
        return -2;
#else
    if ((serverSock = socket(AF_INET, SOCK_STREAM, 0)) < 0)
        return -2;
#endif

    serverSockAddr.sin_family = AF_INET;
    serverSockAddr.sin_port = htons((unsigned short)port);
    serverSockAddr.sin_addr.s_addr = htonl(INADDR_ANY);

    if (bind(serverSock, (struct sockaddr *)&serverSockAddr, sizeof(serverSockAddr)) == -1) {
        result = -3;
    } else if (listen(serverSock, 1024) == -1) {
        result = -4;
    } else {
        for (;;) {
            /* accept() overwrites the length, so reset it before every call. */
            clientSockAddrSize = sizeof(clientSockAddr);
            clientSock = accept(serverSock, (struct sockaddr *)&clientSockAddr, &clientSockAddrSize);
            if (clientSock == -1) break;
            cb(p, clientSock);
        }
    }

    /* Release the listening socket on every path out of the function. */
#if (defined __WIN32__) || (defined WIN32)
    closesocket(serverSock);
#else
    close(serverSock);
#endif
    return result;
}
